{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1253ac2f-6750-43c7-827c-ebe0ed0499b0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.2.0.dev20230916+cu121\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "print(torch.__version__)\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "54a3ad37-f151-43d9-b555-fcd8deddc0e9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1.99\n"
     ]
    }
   ],
   "source": [
    "import sentencepiece as spm\n",
    "print(spm.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f8ab86d2-8d7e-4af3-afbf-3fb6fa663988",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_model(fname, prefix):\n",
    "    spm.SentencePieceTrainer.train(input=fname, model_prefix=prefix, vocab_size=16000)\n",
    "    \n",
    "# Corpus file and the name prefix used for the tokenizer model.\n",
    "corpus, prefix = \"bird_shooter.txt\", \"bird_shooter\"\n",
    "train_model(fname=corpus, prefix=prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "abf3510e-cb35-4854-9a76-a139d9f553f4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data split of train has 505009 tokens\n",
      "data split of test has 58877 tokens\n"
     ]
    }
   ],
   "source": [
    "def load_tokenizer(model_file):\n",
    "    sp = spm.SentencePieceProcessor()\n",
    "    if not sp.load(model_file=model_file):\n",
    "        return False, None\n",
    "    else:\n",
    "        return True, sp\n",
    "\n",
    "def load_file_into_splits(text_file, split_ratio):\n",
    "    with open(text_file, 'r') as file:\n",
    "        data = file.read()\n",
    "    split_idx = int(len(data) * split_ratio)\n",
    "    return data[:split_idx], data[split_idx:]\n",
    "\n",
    "import numpy as np\n",
    "def encode_and_save(sp, content, prefix):\n",
    "    token_ids = sp.encode(content, out_type=int)\n",
    "    print(f\"data split of {prefix} has {len(token_ids)} tokens\")\n",
    "    token_ids = np.array(token_ids, dtype=np.int32)\n",
    "    token_ids.tofile(\"{}.dat\".format(prefix))\n",
    "    \n",
    "import sys\n",
    "def gen_dataset(text_file, model_file):\n",
    "    flag, sp = load_tokenizer(model_file)\n",
    "    if not flag:\n",
    "        print(f\"load tokenizer model from: {model_file} failed\")\n",
    "        sys.exit(1)\n",
    "    split_ratio = 0.9\n",
    "    train_text, test_text = load_file_into_splits(text_file, split_ratio)\n",
    "    encode_and_save(sp, train_text, \"train\")\n",
    "    encode_and_save(sp, test_text, \"test\")\n",
    "    \n",
    "gen_dataset(text_file=corpus, model_file=f\"{prefix}.model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d1f4331e-9e16-42ca-8f57-77785d2095d8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "features:  他不会不给六王爷的面子。”完颜洪烈道\n",
      "targets:  不会不给六王爷的面子。”完颜洪烈道:“\n",
      "features:  女儿家的痴情呆想,这人哪里\n",
      "targets:  家的痴情呆想,这人哪里是甚么\n",
      "features:  得更低了。完颜康心中一荡,伸出左臂\n",
      "targets:  更低了。完颜康心中一荡,伸出左臂去\n",
      "features:  呢”郭靖道,“我接她到桃花岛上\n",
      "targets:  ”郭靖道,“我接她到桃花岛上住\n"
     ]
    }
   ],
   "source": [
    "def get_batch(data, batch_size=4):\n",
    "    win_len = 10\n",
    "    ix = torch.randint(len(data)-win_len, (batch_size,))\n",
    "    x = np.stack([data[i:i+win_len] for i in ix])\n",
    "    y = np.stack([data[i+1:i+1+win_len] for i in ix])\n",
    "    return x, y\n",
    "\n",
    "model_file = f\"{prefix}.model\"\n",
    "\n",
    "def gen_samples(fname):\n",
    "    train_data = np.memmap(fname, dtype=np.int32, mode='r')\n",
    "    x, y = get_batch(train_data)\n",
    "    \n",
    "    flag, sp = load_tokenizer(model_file)\n",
    "    if not flag:\n",
    "        print(f\"load tokenizer model from: {model_file} failed\")\n",
    "        sys.exit(1)\n",
    "        \n",
    "    for features, targets in zip(x, y):\n",
    "        print(\"features: \", sp.decode(features.tolist()))\n",
    "        print(\"targets: \", sp.decode(targets.tolist()))\n",
    "\n",
    "gen_samples(fname=\"train.dat\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
